{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "# Load the Fashion-MNIST train/test splits that ship with Keras\n",
    "fashion_mnist = keras.datasets.fashion_mnist\n",
    "(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()\n",
    "\n",
    "# Normalize the data: divide both image sets by 255.0\n",
    "train_images, test_images = train_images / 255.0, test_images / 255.0\n",
    "\n",
    "# Display names, looked up by integer class label when plotting\n",
    "class_names = [\n",
    "    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',\n",
    "    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',\n",
    "]\n",
    "\n",
    "# Define the model: flatten the 28x28 input, then a 128-unit ReLU layer,\n",
    "# then a 10-unit softmax output layer\n",
    "model = keras.Sequential()\n",
    "model.add(keras.layers.Flatten(input_shape=(28, 28)))\n",
    "model.add(keras.layers.Dense(128, activation=tf.nn.relu))\n",
    "model.add(keras.layers.Dense(10, activation=tf.nn.softmax))\n",
    "\n",
    "model.compile(\n",
    "    optimizer=tf.keras.optimizers.Adam(),\n",
    "    loss='sparse_categorical_crossentropy',\n",
    "    metrics=['accuracy'],\n",
    ")\n",
    "\n",
    "# Train the model, then evaluate it and predict on the test set\n",
    "model.fit(train_images, train_labels, epochs=5, verbose=2)\n",
    "test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)\n",
    "predictions = model.predict(test_images)\n",
    "print('\\nTest accuracy:', test_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_image(i, predictions_array, true_label, img):\n",
    "    \"\"\"Show sample ``i`` of ``img`` with a caption comparing prediction and truth.\n",
    "\n",
    "    ``predictions_array`` is used as given (it is not indexed with ``i``),\n",
    "    whereas ``true_label`` and ``img`` are both indexed with ``i`` here.\n",
    "    The x-label holds the predicted class name, the largest score as a\n",
    "    percentage, and the true class name in parentheses. It is green when\n",
    "    the arg-max prediction equals the true label and red otherwise.\n",
    "    Draws with pyplot and returns nothing.\n",
    "    \"\"\"\n",
    "    actual, picture = true_label[i], img[i]\n",
    "\n",
    "    # Plain image panel: no grid, no tick marks on either axis\n",
    "    plt.grid(False)\n",
    "    plt.xticks([])\n",
    "    plt.yticks([])\n",
    "    plt.imshow(picture, cmap=plt.cm.binary)\n",
    "\n",
    "    guess = np.argmax(predictions_array)\n",
    "    caption_color = 'green' if guess == actual else 'red'\n",
    "    caption = \"{} {:2.0f}% ({})\".format(\n",
    "        class_names[guess],\n",
    "        100 * np.max(predictions_array),\n",
    "        class_names[actual],\n",
    "    )\n",
    "    plt.xlabel(caption, color=caption_color)\n",
    "\n",
    "\n",
    "def plot_value_array(i, predictions_array, true_label):\n",
    "    \"\"\"Bar-chart one sample's class scores, marking predicted and true class.\n",
    "\n",
    "    Args:\n",
    "        i: Index into ``true_label`` of the sample being shown.\n",
    "        predictions_array: Scores for that single sample, one per class.\n",
    "            It is used as given and is not indexed with ``i``.\n",
    "        true_label: Integer class labels; element ``i`` is the true class.\n",
    "\n",
    "    All bars start grey; the arg-max bar is coloured red and the true-class\n",
    "    bar green. Green is applied last, so a correct prediction shows a single\n",
    "    green bar. The y-axis is fixed to [0, 1]. Draws with pyplot and returns\n",
    "    nothing.\n",
    "    \"\"\"\n",
    "    true_label = true_label[i]\n",
    "    # Take the number of bars from the scores rather than hard-coding 10,\n",
    "    # so the helper also works for models with a different class count.\n",
    "    class_ids = range(len(predictions_array))\n",
    "    plt.grid(False)\n",
    "    plt.xticks(class_ids)\n",
    "    plt.yticks([])\n",
    "    thisplot = plt.bar(class_ids, predictions_array, color=\"#777777\")\n",
    "    plt.ylim([0, 1])\n",
    "    predicted_label = np.argmax(predictions_array)\n",
    "\n",
    "    thisplot[predicted_label].set_color('red')\n",
    "    thisplot[true_label].set_color('green')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show test sample i: its image in the left panel, its class scores in the right\n",
    "i = 1000\n",
    "plt.figure(figsize=(6, 3))\n",
    "\n",
    "plt.subplot(121)\n",
    "plot_image(i, predictions[i], test_labels, test_images)\n",
    "\n",
    "plt.subplot(122)\n",
    "plot_value_array(i, predictions[i], test_labels)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled example (left commented out, so this cell does nothing when run).\n",
    "# If enabled, it would draw the first num_rows * num_cols test samples in a\n",
    "# grid, placing each image next to its prediction bar chart.\n",
    "# num_rows = 5\n",
    "# num_cols = 3\n",
    "# num_images = num_rows * num_cols\n",
    "# plt.figure(figsize=(2 * 2 * num_cols, 2 * num_rows))\n",
    "# for i in range(num_images):\n",
    "#     plt.subplot(num_rows, 2 * num_cols, 2 * i + 1)\n",
    "#     plot_image(i, predictions[i], test_labels, test_images)\n",
    "#     plt.subplot(num_rows, 2 * num_cols, 2 * i + 2)\n",
    "#     plot_value_array(i, predictions[i], test_labels)\n",
    "# plt.tight_layout()\n",
    "# plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
